// Package client 包含了与app端通讯的接口事务处理。
package client

import (
	"acs/comet/proto"
	"acs/pbmodel"
	"acs/util"
	"bufio"
	"errors"
	"fmt"
	"net"
	"runtime/debug"
	"sync"
	"time"

	log "github.com/cihub/seelog"
	pb "github.com/golang/protobuf/proto"
)

var (
	// ErrorInvalidMessage is returned by PushMessage for a packet that is not
	// a *proto.Request, *proto.Response or *proto.ProtoError.
	ErrorInvalidMessage = errors.New("Invalid message to send, the message should be *Request, *Response or *ProtoError.")
	// ErrorClientClosing is returned by PushMessage once Close has been called.
	ErrorClientClosing = errors.New("Client is closing.")
	// ErrClientConnBroken is sent on Client.ConnErr when a write to the
	// connection fails.
	ErrClientConnBroken = errors.New("Client connection broken")
)

// logger is the package-wide logger; it starts as seelog.Default and can be
// replaced with SetLogger. setMutex serializes SetLogger calls.
var (
	logger   = log.Default
	setMutex = &sync.Mutex{}
)

const (
	// MAX_TRANSACTION_ID is the largest transaction ID handed out by
	// genTransactionID, which cycles through 1..MAX_TRANSACTION_ID.
	// In Go `^` is bitwise XOR, not exponentiation (2 ^ 14 - 1 == 11), so a
	// shift is required to express 2**14 - 1.
	MAX_TRANSACTION_ID = 1<<14 - 1
	// DEFAULT_HANDSHAKE_TIMEOUT is the handshake timeout in milliseconds,
	// used by NewClient when ClientConfig.HandshakeTimeout is 0.
	DEFAULT_HANDSHAKE_TIMEOUT = 1000
)

// SetLogger replaces the logger used by the client package.
// The initial logger is seelog.Default.
//
// NOTE(review): setMutex only serializes concurrent SetLogger calls; the rest
// of the package reads logger without taking the lock, so call this before
// any Client is created to avoid a data race.
func SetLogger(newLogger log.LoggerInterface) {
	setMutex.Lock()
	logger = newLogger
	setMutex.Unlock()
}

// ClientConfig holds the per-connection settings passed to NewClient.
// NewClient keeps a pointer to it and fills in HandshakeTimeout in place when
// it is 0.
type ClientConfig struct {
	// Conn is the peer connection: HandleWrite writes queued packets to it and
	// Close closes it.
	Conn              net.Conn
	// ExitWaitGroup tracks running request/response handlers; the Handle*
	// methods call Add/Done on it, so it must be non-nil.
	ExitWaitGroup     *sync.WaitGroup
	// MaxReqRespTimeout, in milliseconds, bounds how long SendCmd waits for a
	// response and how long HandleRequest/HandleResponse wait for a handler
	// or callback.
	MaxReqRespTimeout uint32
	// HandshakeTimeout, in milliseconds, is how long the peer has to complete
	// the handshake; 0 means DEFAULT_HANDSHAKE_TIMEOUT.
	HandshakeTimeout  uint32
	// DataEncryptKey and DataEncryptIV are passed to proto.NewRequest,
	// proto.NewResponse and proto.NewProtoError for outgoing packets.
	DataEncryptKey    string
	DataEncryptIV     string
}
// RequestResponse is the outcome of a request/response exchange: the decoded
// response message, or the error that prevented obtaining it.
type RequestResponse struct {
	// Err reports why the exchange failed; nil on success.
	Err      error
	// RespData is the response message (for outgoing requests, the respData
	// message supplied by the caller, filled in by HandleResponse).
	RespData pb.Message
}

// RequestInfo records a request (cmd) sent to the peer so that the matching
// response can be decoded and delivered when it arrives.
type RequestInfo struct {
	// sendTime is when the request was registered, in Unix nanoseconds.
	sendTime int64
	// callback handles the response of an asynchronous request (SendCmdAsync);
	// HandleResponse invokes it on a separate goroutine.
	callback func(*RequestResponse)
	// respData is the message the peer's response payload is unmarshalled into.
	respData pb.Message
	// reqRespChan receives the response for a synchronous request (SendCmd).
	reqRespChan chan *RequestResponse
}

// requestInfoList holds this side's pending requests, keyed by transaction
// ID, so responses from the peer can be matched to them.
type requestInfoList struct {
	// items is guarded by rwlock.
	items  map[uint16]*RequestInfo
	rwlock *sync.RWMutex
}

// notifyEvent is a string-valued event identifier.
// NOTE(review): not referenced in this file — presumably names notifications
// sent to peers; confirm against callers.
type notifyEvent string

// Known notify events.
const (
	NotifyEventNewPatch  notifyEvent = "new_patch"
	NotifyEventNewSchema notifyEvent = "new_schema"
)

// Client manages one peer connection: the outgoing packet queue, pending
// request bookkeeping, registration/handshake state and shutdown.
type Client struct {
	// closing is set once by Close; guarded by closingLock.
	closing            bool
	closingLock        *sync.RWMutex
	// regInfoLock guards registerInfo and handshaked.
	regInfoLock        *sync.RWMutex
	// buf carries encoded packets from PushMessage to the HandleWrite
	// goroutine; Close closes it.
	buf                chan []byte
	// transIdCounter is the last ID returned by genTransactionID; guarded by
	// transIdCounterLock.
	transIdCounter     uint16
	transIdCounterLock *sync.RWMutex
	// Reader and Writer are buffered wrappers around conf.Conn created by
	// NewClient. NOTE(review): not used in this file (HandleWrite writes to
	// conf.Conn directly) — presumably used by a read loop elsewhere.
	Reader             *bufio.Reader
	Writer             *bufio.Writer
	// ConnErr receives ErrClientConnBroken when a write to the connection fails.
	ConnErr            chan error
	// reqInfoList holds requests sent to the peer that still await a response.
	reqInfoList        *requestInfoList
	conf               *ClientConfig
	// registerInfo is the peer's registration data, set via SetRegisterInfo.
	registerInfo       *pbmodel.RegisterInfo
	// ConnectTime is the Unix time (seconds) at which NewClient ran.
	ConnectTime        int64
	// LastRegTime is not written in this file; presumably the Unix time of the
	// peer's last registration — TODO confirm.
	LastRegTime        int64
	// handshaked is set via SetHandshaked once the peer has completed the
	// handshake (registration). While false, checkHandshake may close the
	// session and HandleRequest closes it on any command other than
	// register/ping. Guarded by regInfoLock.
	handshaked bool
}

// Add registers rqi as the pending request for transId.
// If transId is already pending (for example because transaction IDs wrapped
// around while an older request was still outstanding), the list is left
// unchanged and an error is returned.
func (l *requestInfoList) Add(transId uint16, rqi *RequestInfo) error {
	l.rwlock.Lock()
	defer l.rwlock.Unlock()
	if _, exists := l.items[transId]; exists {
		return fmt.Errorf("Request transaction ID %v already exists in the list!", transId)
	}
	l.items[transId] = rqi
	return nil
}

// Consume removes and returns the pending request for transId.
// It returns an error if no request with that ID is pending.
func (l *requestInfoList) Consume(transId uint16) (*RequestInfo, error) {
	l.rwlock.Lock()
	defer l.rwlock.Unlock()
	// Single lookup: fetch the entry and its presence together.
	rqi, ok := l.items[transId]
	if !ok {
		return nil, fmt.Errorf("Request transaction ID %v not exists in the list!", transId)
	}
	delete(l.items, transId)
	return rqi, nil
}

// Exits reports whether a request with transId is still pending.
// The name looks like a misspelling of "Exists"; it is kept unchanged for
// existing callers.
func (l *requestInfoList) Exits(transId uint16) bool {
	l.rwlock.RLock()
	_, found := l.items[transId]
	l.rwlock.RUnlock()
	return found
}

// NewClient builds a Client around conf.Conn and starts the handshake
// watchdog goroutine (checkHandshake). A zero conf.HandshakeTimeout is
// replaced in place with DEFAULT_HANDSHAKE_TIMEOUT. The writer goroutine
// (HandleWrite) is not started here.
func NewClient(conf *ClientConfig) *Client {
	if conf.HandshakeTimeout == 0 {
		conf.HandshakeTimeout = DEFAULT_HANDSHAKE_TIMEOUT
	}
	pending := &requestInfoList{
		items:  make(map[uint16]*RequestInfo),
		rwlock: new(sync.RWMutex),
	}
	c := &Client{
		conf:               conf,
		ConnectTime:        time.Now().Unix(),
		closingLock:        new(sync.RWMutex),
		regInfoLock:        new(sync.RWMutex),
		transIdCounterLock: new(sync.RWMutex),
		buf:                make(chan []byte, 1),
		ConnErr:            make(chan error, 10),
		Reader:             bufio.NewReader(conf.Conn),
		Writer:             bufio.NewWriter(conf.Conn),
		reqInfoList:        pending,
	}
	go c.checkHandshake()
	return c
}

// IsClosing reports whether Close has been called on this client.
func (c *Client) IsClosing() bool {
	// Reading the flag only needs the shared lock; Close takes the write lock.
	c.closingLock.RLock()
	defer c.closingLock.RUnlock()
	return c.closing
}

// checkHandshake is started by NewClient. It closes the session if the peer
// has not completed the handshake (see SetHandshaked) within
// conf.HandshakeTimeout milliseconds.
//
// The handshake flag is polled every 20ms. The goroutine returns as soon as
// the handshake is done or the client is already closing, so it does not
// linger after an early Close.
func (c *Client) checkHandshake() {
	tv := time.Millisecond * time.Duration(c.conf.HandshakeTimeout)
	deadline := time.NewTimer(tv)
	defer deadline.Stop()
	poll := time.NewTicker(time.Millisecond * 20)
	defer poll.Stop()
	for {
		if c.IsHandshaked() || c.IsClosing() {
			return
		}
		select {
		case <-deadline.C:
			logger.Warnf("[remote %v] Client handshake timedout after [%v]. closing session...", c.conf.Conn.RemoteAddr(), tv)
			c.Close()
			return
		case <-poll.C:
		}
	}
}

// PushMessage encodes packet and queues it for the HandleWrite goroutine.
//
// packet must be a *proto.Request, *proto.Response or *proto.ProtoError;
// anything else yields ErrorInvalidMessage. ErrorClientClosing is returned
// once Close has been called. A panic raised while encoding is recovered and
// returned as an error.
//
// closingLock is held (read) across the channel send so Close cannot close
// c.buf underneath it; the send blocks while the buffer is full.
func (c *Client) PushMessage(packet interface{}) (err error) {
	defer func() {
		if errFatal := recover(); errFatal != nil {
			util.PrintPanicStack()
			err = fmt.Errorf("[remote %v]PushMessage fatal error: %v\n%s", c.conf.Conn.RemoteAddr(), errFatal, debug.Stack())
		}
	}()
	c.closingLock.RLock()
	defer c.closingLock.RUnlock()
	if c.closing {
		return ErrorClientClosing
	}

	logger.Debugf("[remote %v]Send data to: %+v", c.conf.Conn.RemoteAddr(), packet)
	var encoded []byte
	switch p := packet.(type) {
	case *proto.Request:
		encoded = p.Encode()
	case *proto.Response:
		encoded = p.Encode()
	case *proto.ProtoError:
		encoded = p.Encode()
	default:
		logger.Errorf("%v", ErrorInvalidMessage)
		return ErrorInvalidMessage
	}
	c.buf <- encoded
	return nil
}

// HandleWrite starts the writer goroutine. The goroutine drains c.buf (filled
// by PushMessage) and writes each encoded packet to conf.Conn. It exits when
// Close closes c.buf.
//
// A failed Write is reported as ErrClientConnBroken on ConnErr. The report is
// non-blocking. If ConnErr is already full, the consumer already has evidence
// that the connection is broken. Blocking here would stop draining c.buf, and
// PushMessage would then block while holding closingLock, so Close could
// never acquire it.
func (c *Client) HandleWrite() {
	go func() {
		for message := range c.buf {
			if _, err := c.conf.Conn.Write(message); err != nil {
				select {
				case c.ConnErr <- ErrClientConnBroken:
				default:
				}
				logger.Warnf("[remote %v] conn Write error: %v", c.conf.Conn.RemoteAddr(), err)
			}
		}
		logger.Infof("[remote %v]client handle routine stop", c.conf.Conn.RemoteAddr())
	}()
}

// Close marks the client as closing, closes the outgoing buffer (ending the
// HandleWrite goroutine) and closes the underlying connection, returning the
// connection's Close error. Only the first call has any effect; later calls
// return nil.
func (c *Client) Close() error {
	c.closingLock.Lock()
	defer c.closingLock.Unlock()
	if !c.closing {
		c.closing = true
		close(c.buf)
		return c.conf.Conn.Close()
	}
	return nil
}

// genTransactionID returns the next transaction ID, cycling through
// 1..MAX_TRANSACTION_ID (0 is never returned).
func (c *Client) genTransactionID() uint16 {
	c.transIdCounterLock.Lock()
	defer c.transIdCounterLock.Unlock()
	next := c.transIdCounter%MAX_TRANSACTION_ID + 1
	c.transIdCounter = next
	return next
}

// GetRegisterInfo returns a copy of the peer's registration data, or the zero
// value if SetRegisterInfo has not been called. The copy is shallow: any
// pointer or slice fields are shared with the stored value.
func (c *Client) GetRegisterInfo() pbmodel.RegisterInfo {
	c.regInfoLock.RLock()
	defer c.regInfoLock.RUnlock()
	var info pbmodel.RegisterInfo
	if c.registerInfo != nil {
		info = *c.registerInfo
	}
	return info
}

// SetRegisterInfo stores the peer's registration data (the client keeps its
// own copy of info).
func (c *Client) SetRegisterInfo(info pbmodel.RegisterInfo) {
	c.regInfoLock.Lock()
	c.registerInfo = &info
	c.regInfoLock.Unlock()
}

// GetRemoteAddr returns the peer address of the underlying connection.
func (c *Client) GetRemoteAddr() string {
	remote := c.conf.Conn.RemoteAddr()
	return remote.String()
}

// SetHandshaked records whether the peer has completed the handshake.
func (c *Client) SetHandshaked(v bool) {
	c.regInfoLock.Lock()
	c.handshaked = v
	c.regInfoLock.Unlock()
}

// IsHandshaked reports whether the peer has completed the handshake.
func (c *Client) IsHandshaked() bool {
	c.regInfoLock.RLock()
	defer c.regInfoLock.RUnlock()
	return c.handshaked
}

// SendCmdAsync sends request cmd to the peer without waiting for the reply.
// cmdData is the protocol (protobuf) message to send. When the matching
// response arrives, HandleResponse unmarshals it into respData (a protocol
// protobuf message) and invokes callback with the result.
//
// The returned error reports marshalling, transaction registration and send
// failures. If the send fails, the pending entry is removed again, so its
// transaction ID is freed and callback will not be invoked.
func (c *Client) SendCmdAsync(cmd string, cmdData pb.Message, respData pb.Message, callback func(reqResp *RequestResponse)) (err error) {
	var data []byte
	if data, err = pb.Marshal(cmdData); err != nil {
		return err
	}
	req := proto.NewRequest(c.genTransactionID(), cmd, data, c.conf.DataEncryptKey, c.conf.DataEncryptIV)
	info := &RequestInfo{sendTime: time.Now().UnixNano(), callback: callback, respData: respData}
	if err = c.reqInfoList.Add(req.Tsid, info); err != nil {
		return err
	}
	if err = c.PushMessage(req); err != nil {
		// Nothing was queued, so no response will ever consume this entry;
		// drop it so the transaction ID does not stay occupied forever.
		_, _ = c.reqInfoList.Consume(req.Tsid)
		return err
	}
	return nil
}

// SendCmd sends request cmd to the peer and blocks until the matching response
// arrives or conf.MaxReqRespTimeout milliseconds elapse.
// cmdData is the protocol (protobuf) message to send. The response payload is
// unmarshalled into respData (a protocol protobuf message), which is returned
// as reqResp.RespData. Failures are reported in reqResp.Err; the returned
// pointer is never nil.
//
// On send failure or timeout the pending entry is removed, which frees the
// transaction ID for reuse. A late response then finds no entry in
// HandleResponse and is only logged.
func (c *Client) SendCmd(cmd string, cmdData pb.Message, respData pb.Message) (reqResp *RequestResponse) {
	reqResp = new(RequestResponse)
	data, err := pb.Marshal(cmdData)
	if err != nil {
		reqResp.Err = err
		return reqResp
	}
	req := proto.NewRequest(c.genTransactionID(), cmd, data, c.conf.DataEncryptKey, c.conf.DataEncryptIV)
	// Capacity 1 lets HandleResponse deliver without blocking even if this
	// call has already timed out and stopped receiving.
	respChan := make(chan *RequestResponse, 1)
	info := &RequestInfo{sendTime: time.Now().UnixNano(), respData: respData, reqRespChan: respChan}
	if err = c.reqInfoList.Add(req.Tsid, info); err != nil {
		reqResp.Err = err
		return reqResp
	}
	if err = c.PushMessage(req); err != nil {
		// Nothing was queued, so no response will ever consume this entry.
		_, _ = c.reqInfoList.Consume(req.Tsid)
		reqResp.Err = err
		return reqResp
	}
	tv := time.Millisecond * time.Duration(c.conf.MaxReqRespTimeout)
	select {
	case reqResp = <-respChan:
	case <-time.After(tv):
		// Free the transaction ID; it may already be gone if the response
		// raced with the timeout, which is harmless.
		_, _ = c.reqInfoList.Consume(req.Tsid)
		reqResp.Err = fmt.Errorf("[remote %v]cmd [%v] to client execute timedout: [%v]", c.conf.Conn.RemoteAddr(), cmd, tv)
	}
	return reqResp
}

// HandleRemoteProtoError decodes a ProtoError packet received from the peer
// and logs it. A payload that fails to decode is logged and otherwise ignored.
func (c *Client) HandleRemoteProtoError(resp *proto.ProtoError) {
	wg := c.conf.ExitWaitGroup
	wg.Add(1)
	defer wg.Done()

	errInfo := &pbmodel.ProtoErr{}
	if err := pb.Unmarshal(resp.Data, errInfo); err != nil {
		logger.Warnf("ProtoError messge decode error: %v", err)
		return
	}
	logger.Warnf("[remote %v]Got proto error for transaction[%v]: %+v", c.conf.Conn.RemoteAddr(), resp.Tsid, errInfo)
}

// HandleProtoError sends a ProtoError packet for transaction transId to the
// peer, carrying errCode and err's message. A nil err is sent with an empty
// message instead of panicking. Encoding and send failures are logged.
func (c *Client) HandleProtoError(transId uint16, errCode int32, err error) {
	c.conf.ExitWaitGroup.Add(1)
	defer c.conf.ExitWaitGroup.Done()

	msg := ""
	if err != nil {
		msg = err.Error()
	}
	errInfo := &pbmodel.ProtoErr{
		Code: pb.Int32(errCode),
		Msg:  pb.String(msg),
	}

	data, encErr := pb.Marshal(errInfo)
	if encErr != nil {
		logger.Warnf("[remote %v]encoding pb protoErr error: %v", c.conf.Conn.RemoteAddr(), encErr)
		return
	}
	packet := proto.NewProtoError(transId, data, c.conf.DataEncryptKey, c.conf.DataEncryptIV)
	if pushErr := c.PushMessage(packet); pushErr != nil {
		logger.Warnf("[remote %v]sending protoErr for transaction[%v] error: %v", c.conf.Conn.RemoteAddr(), transId, pushErr)
	}
}

// HandleRequest dispatches a request received from the peer to the handler for
// req.Cmd. If the handler succeeds, its result is sent back as a Response with
// the same transaction ID. It should run in its own goroutine.
//
// Until the handshake has completed, only CmdRegister and CmdPing are
// accepted; any other command closes the session and is not dispatched.
// The handler runs on a separate goroutine, and HandleRequest waits at most
// conf.MaxReqRespTimeout milliseconds for it (the handler is not cancelled).
// On timeout, handler error, handler panic or encoding failure nothing is sent
// back and the problem is only logged.
func (c *Client) HandleRequest(req *proto.Request) {
	c.conf.ExitWaitGroup.Add(1)
	logger.Debugf("[remote %v]Handling request: %+v", c.conf.Conn.RemoteAddr(), *req)
	defer func() {
		c.conf.ExitWaitGroup.Done()
		if err := recover(); err != nil {
			util.PrintPanicStack()
			logger.Errorf("[remote %v]request handler exited unexpectedly -LocalAddr: %v [ %v ]: %s", c.conf.Conn.RemoteAddr(), c.conf.Conn.LocalAddr(), err, debug.Stack())
		} else {
			logger.Debugf("[remote %v]Handled request: %+v", c.conf.Conn.RemoteAddr(), *req)
		}
	}()
	if !c.IsHandshaked() && req.Cmd != proto.CmdRegister && req.Cmd != proto.CmdPing {
		logger.Warnf("Error: client commands[%v] before handshake!", req.Cmd)
		c.Close()
		// Never run a handler for a peer that has not completed the handshake.
		return
	}

	// Go switch cases never fall through, so commands that share a handler
	// must be listed in the same case.
	var reqHandle HandleFunc
	switch req.Cmd {
	case proto.CmdPing:
		reqHandle = HandlePing
	case proto.CmdRegister:
		reqHandle = HandleRegister
	case proto.CmdEvent:
		reqHandle = HandleEvent
	// NOTE(review): CmdPatch/CmdSchema are routed to the *Fin handlers,
	// following the C-style fallthrough layout of this switch — confirm.
	case proto.CmdPatch, proto.CmdPathfin:
		reqHandle = HandlePatchFin
	case proto.CmdSchema, proto.CmdSchemafin:
		reqHandle = HandleSchemaFin
	case proto.CmdForward:
		reqHandle = HandleForward
	default:
		reqHandle = HandleDefault
	}

	// Buffered so the handler goroutine can always deliver its result and
	// exit, even after this function has stopped waiting on timeout.
	processChan := make(chan *RequestResponse, 1)
	go func() {
		defer func() {
			if err := recover(); err != nil {
				util.PrintPanicStack()
				logger.Errorf("[remote %v]request handler exited unexpectedly, LocalAddr: %v [ %v ]: %s", c.conf.Conn.RemoteAddr(), c.conf.Conn.LocalAddr(), err, debug.Stack())
				// Report the panic right away instead of leaving the caller
				// waiting for the full timeout.
				processChan <- &RequestResponse{Err: fmt.Errorf("request handler panic: %v", err)}
			}
		}()
		m, err := reqHandle(req, c)
		processChan <- &RequestResponse{RespData: m, Err: err}
	}()

	tv := time.Millisecond * time.Duration(c.conf.MaxReqRespTimeout)
	select {
	case <-time.After(tv):
		logger.Warnf("[remote %v]response handler callback for transaction[%v] timed out [%v], LocalAddr: %v", c.conf.Conn.RemoteAddr(), req.Tsid, tv, c.conf.Conn.LocalAddr())
	case result := <-processChan:
		logger.Debugf("[remote %v] requset process result: %+v", c.conf.Conn.RemoteAddr(), result)
		if result.Err != nil {
			logger.Infof("[remote %v]request process error: %v", c.conf.Conn.RemoteAddr(), result.Err)
			return
		}
		data, err := pb.Marshal(result.RespData)
		if err != nil {
			logger.Infof("[remote %v]request process result encoding error: %v", c.conf.Conn.RemoteAddr(), err)
			return
		}
		if err := c.PushMessage(proto.NewResponse(req.Tsid, data, c.conf.DataEncryptKey, c.conf.DataEncryptIV)); err != nil {
			logger.Infof("[remote %v]sending response for transaction[%v] error: %v", c.conf.Conn.RemoteAddr(), req.Tsid, err)
		}
	}
}

// HandleResponse matches a Response from the peer to the pending request with
// the same transaction ID and unmarshals the payload into that request's
// respData. The result then goes to one of two places:
//   - the SendCmdAsync callback, run on a new goroutine; a warning is logged
//     if it runs longer than conf.MaxReqRespTimeout milliseconds, but it is
//     not cancelled;
//   - the channel SendCmd is waiting on.
//
// Responses with no pending request are logged and dropped. It should run in
// its own goroutine.
func (c *Client) HandleResponse(resp *proto.Response) {
	c.conf.ExitWaitGroup.Add(1)
	defer func() {
		c.conf.ExitWaitGroup.Done()
		if err := recover(); err != nil {
			util.PrintPanicStack()
			logger.Errorf("[remote %v]response handler exited unexpectedly, LocalAddr: %v [ %v ]: %s", c.conf.Conn.RemoteAddr(), c.conf.Conn.LocalAddr(), err, debug.Stack())
		}
	}()
	respInfo, err := c.reqInfoList.Consume(resp.Tsid)
	if err != nil {
		logger.Warnf("[remote %v]Failed to get callback for response[%+v]. Got error: %v", c.conf.Conn.RemoteAddr(), *resp, err)
		return
	}
	err = pb.Unmarshal(resp.Data, respInfo.respData)
	reqResp := &RequestResponse{Err: err, RespData: respInfo.respData}

	switch {
	case respInfo.callback != nil:
		// Add before starting the goroutine: this function's own Done runs as
		// soon as it returns, and an Add inside the goroutine could race with
		// ExitWaitGroup.Wait.
		c.conf.ExitWaitGroup.Add(1)
		go func() {
			defer c.conf.ExitWaitGroup.Done()
			// Buffered so the callback goroutine can signal and exit even
			// after we have stopped waiting for it.
			processChan := make(chan bool, 1)
			go func() {
				defer func() {
					// The callback runs on its own goroutine; without this
					// recover a panicking callback would crash the process.
					if err := recover(); err != nil {
						util.PrintPanicStack()
						logger.Errorf("[remote %v]response handler callback exited unexpectedly, LocalAddr: %v [ %v ]: %s", c.conf.Conn.RemoteAddr(), c.conf.Conn.LocalAddr(), err, debug.Stack())
					}
					processChan <- true
				}()
				respInfo.callback(reqResp)
			}()

			tv := time.Millisecond * time.Duration(c.conf.MaxReqRespTimeout)
			select {
			case <-time.After(tv):
				logger.Warnf("[remote %v]response handler callback for transaction[%v] timed out [%v], LocalAddr: %v ", c.conf.Conn.RemoteAddr(), resp.Tsid, tv, c.conf.Conn.LocalAddr())
			case <-processChan:
			}
		}()
	case respInfo.reqRespChan != nil:
		respInfo.reqRespChan <- reqResp
	default:
		logger.Warnf("[remote %v]Error: callback and reqRespChan are nil, response cannot be handled:[%+v]", c.conf.Conn.RemoteAddr(), *reqResp)
	}
}
